---
# Job spec for a ResNet training run. The layout resembles a SkyPilot task
# file -- NOTE(review): confirm the consuming tool.
#
# `workdir` is presumably the local directory shipped to the remote machine;
# the `setup` and `run` scripts reference `models/` relative to it.
name: 'resnet-app-storage'
workdir: '~/Downloads/tpu'

# Where the job runs: provider and machine type.
resources:
  cloud: 'aws'
  instance_type: 'p3.2xlarge'

# Input datasets, keyed by location.
# NOTE(review): the number is presumably an estimated size (GB?) -- confirm
# against the consuming tool's schema.
inputs:
  gs://cloud-tpu-test-dataset/fake_imagenet: 70

# Output artifacts, keyed by name. `resnet-model-dir` matches the
# `--model_dir` value passed in the `run` script.
# NOTE(review): the number is presumably an estimated size (GB?) -- confirm.
outputs:
  resnet-model-dir: 0.1

# Remote path -> storage source. `/tmp/imagenet` is the directory the `run`
# script passes as `--data_dir`.
# NOTE(review): `inputs` lists a gs:// dataset while the job actually reads
# this S3 mount -- confirm which one is intended.
file_mounts:
  /tmp/imagenet:
    source: 's3://imagenet-bucket'
    mode: 'MOUNT'

setup: |
  # Provision the `resnet` conda env on first use; later invocations only
  # re-activate it and skip installation.
  # NOTE(review): if any install step below fails midway, the env still
  # exists, so the next setup run will treat it as complete. Remove the env
  # by hand (`conda env remove -n resnet`) before retrying.
  . $(conda info --base)/etc/profile.d/conda.sh

  # Test the activation directly instead of inspecting $? a few lines later.
  if conda activate resnet; then
    echo "conda env exists"
  else
    conda create -n resnet python=3.7 -y
    conda activate resnet

    # Upgrade pip *inside* the new env; upgrading before activation would
    # only touch the base env's pip.
    python -m pip install --upgrade pip
    conda install cudatoolkit=11.0 -y
    pip install tensorflow==2.4.0 pyyaml

    # Automatically set CUDNN envvars when conda activate is run.
    # The hook tolerates a missing `nvidia.cudnn` module (nothing in this
    # setup installs it -- NOTE(review): confirm whether a cuDNN package is
    # expected here) and never emits empty LD_LIBRARY_PATH entries, which the
    # dynamic loader would interpret as the current directory.
    hook_dir="$CONDA_PREFIX/etc/conda/activate.d"
    mkdir -p "$hook_dir"
    echo 'CUDNN_PATH=$(python -c "import os, nvidia.cudnn; print(os.path.dirname(nvidia.cudnn.__file__))" 2>/dev/null || true)' > "$hook_dir/env_vars.sh"
    echo 'export LD_LIBRARY_PATH="$CONDA_PREFIX/lib/${CUDNN_PATH:+:$CUDNN_PATH/lib}${LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH}"' >> "$hook_dir/env_vars.sh"

    # Install the model package in editable mode without changing directory,
    # so a missing `models/` fails loudly instead of installing the wrong tree.
    pip install -e ./models
  fi

run: |
  # Activate the `resnet` conda env created by the setup script.
  . $(conda info --base)/etc/profile.d/conda.sh
  conda activate resnet

  # Tell XLA where the CUDA installation lives.
  export XLA_FLAGS='--xla_gpu_cuda_data_dir=/usr/local/cuda/'

  # Train ResNet without a TPU: reads data from the /tmp/imagenet mount and
  # writes checkpoints to resnet-model-dir (relative to the working dir).
  python -u models/official/resnet/resnet_main.py \
      --use_tpu=False \
      --mode=train \
      --train_batch_size=256 \
      --train_steps=250 \
      --iterations_per_loop=125 \
      --data_dir=/tmp/imagenet \
      --model_dir=resnet-model-dir \
      --amp \
      --xla \
      --loss_scale=128
